[V1][Kernel] Flashinfer HND KV cache layout #19280
Conversation
Signed-off-by: nicklucche <[email protected]>
Signed-off-by: nicklucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Hello @NickLucche, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello team, gemini-code-assist here with a summary of this pull request. This PR, authored by @NickLucche, enables the HND KV cache layout (heads first, then tokens within a block, then head dimension) specifically for the FlashInfer attention backend in vLLM V1. It is a follow-up to previous work and is intended as a prerequisite for features like heterogeneous tensor parallelism (TP) with a disaggregated prefill-decode setup, since the HND layout is better optimized for KV cache transfers in such scenarios. The changes modify the FlashInfer backend to determine the desired KV cache layout dynamically (via an environment variable or the vLLM config) and to apply the necessary tensor permutation before calling the FlashInfer kernels.
Highlights
- FlashInfer HND KV Cache Layout: Introduces support for the HND KV cache layout (heads first, then tokens, then head dimension) when using the FlashInfer attention backend in V1.
- Configurable KV Cache Layout: Allows the KV cache layout for FlashInfer to be specified via the `FLASHINFER_KV_CACHE_LAYOUT` environment variable, overriding the default or configured layout.
- Tensor Permutation for Layout: Implements logic within the FlashInfer backend to permute the KV cache tensor to the desired HND layout before passing it to the FlashInfer kernels (see the sketch after this list).
- Enabling Disaggregated PD: This change is highlighted as a necessary step towards enabling heterogeneous TP and optimizing data transfers for disaggregated prefill-decode setups.
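For orientation, here is a minimal sketch of what the stride-order permutation does. The five-dimensional cache shape used below is an assumption for illustration only; the `(0, 1, 3, 2, 4)` ordering is the one this PR uses for HND (quoted later in this thread).

```python
import torch

# Assumed V1 FlashInfer KV cache shape:
# (num_blocks, 2, block_size, num_kv_heads, head_dim); sizes are illustrative.
num_blocks, block_size, num_kv_heads, head_dim = 4, 16, 8, 128
kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_dim)

# NHD order: tokens (N), then heads (H), then head dim (D) inside each block.
print(kv_cache.shape)                 # torch.Size([4, 2, 16, 8, 128])

# HND order: the (0, 1, 3, 2, 4) stride order swaps the token and head axes,
# yielding (num_blocks, 2, num_kv_heads, block_size, head_dim).
hnd_view = kv_cache.permute(0, 1, 3, 2, 4)
print(hnd_view.shape)                 # torch.Size([4, 2, 8, 16, 128])

# permute() only creates a view; no data is copied.
assert hnd_view.data_ptr() == kv_cache.data_ptr()
```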
Changelog
- `vllm/distributed/kv_transfer/kv_connector/utils.py`
  - Removed a redundant blank line (L6).
  - Added a note explaining that the HND layout is used for faster transfer in disaggregated PD with NIXL (L96-97).
  - Improved the check in `get_kv_connector_cache_layout` to also verify that `kv_config` is not None before defaulting to NHD (L100) (see the sketch after this changelog).
- `vllm/v1/attention/backends/flashinfer.py`
  - Imported `functools` and `os` (L6-7).
  - Imported `get_kv_connector_cache_layout` from `vllm.distributed.kv_transfer.kv_connector.utils` (L21-22).
  - Added a module-level variable `FLASHINFER_KV_CACHE_LAYOUT` to read the environment variable (L35-36).
  - Added a cached function `get_flashinfer_kv_cache_layout` to determine the layout, prioritizing the environment variable (L41-51).
  - Added a static method `get_kv_cache_stride_order` to `FlashInferBackend` to return the permutation tuple based on the determined layout (L87-98).
  - Updated the initialization of `BatchPrefillWithPagedKVCacheWrapper` to use the layout from `get_flashinfer_kv_cache_layout()` instead of hardcoded "NHD" (L324).
  - Updated the initialization of `BatchDecodeWithPagedKVCacheWrapper` to use the layout from `get_flashinfer_kv_cache_layout()` instead of hardcoded "NHD" (L337).
  - Updated the initialization of `MultiLevelCascadeAttentionWrapper` to use the layout from `get_flashinfer_kv_cache_layout()` instead of hardcoded "NHD" (L345).
  - Called `FlashInferBackend.get_kv_cache_stride_order()` in the `forward` method to get the required permutation (L643).
  - Applied `.permute(*stride_order)` to the `kv_cache` tensor before passing it to `prefill_wrapper.run` (L658).
  - Applied `.permute(*stride_order)` to the `kv_cache` tensor before passing it to `decode_wrapper.run` (L674).
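To make the `utils.py` change above concrete, here is a minimal sketch of the kind of defensive fallback being described. The function name, config fields, and connector name come from this PR and vLLM, but the body below is an assumption, not the actual implementation.

```python
from vllm.config import get_current_vllm_config


def get_kv_connector_cache_layout() -> str:
    # Hedged sketch; the real implementation lives in
    # vllm/distributed/kv_transfer/kv_connector/utils.py.
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    if kv_config is not None and kv_config.kv_connector == "NixlConnector":
        # HND keeps each attention head's tokens contiguous within a block,
        # which speeds up per-head KV transfers in disaggregated
        # prefill-decode setups using NIXL.
        return "HND"
    # kv_config is None (no connector configured) or the connector has no
    # layout preference: fall back to FlashInfer's default NHD layout.
    return "NHD"
```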
Code Review
This pull request successfully enables the HND KV cache layout for FlashInfer in V1, enhancing flexibility and potentially performance for disaggregated prefill-decode setups. The changes are well-structured, introducing a clear mechanism for determining the cache layout (via environment variable or vLLM config) and applying it consistently.
Key improvements include:
- Robust handling of KV cache layout determination with user overrides.
- Centralized logic for stride order based on the selected layout.
- Consistent application of the layout across prefill, decode, and cascade attention wrappers.
A minor issue with error message formatting was found, which should be addressed. Overall, the PR is a good step towards optimizing KV cache transfers.
Summary of Findings
- Incorrect ValueError Formatting: In `vllm/v1/attention/backends/flashinfer.py`, the `ValueError` raised for an unknown cache layout uses incorrect string formatting: the `%s` placeholder is never interpolated, so the raised error message would be confusing. It should be updated to use f-string formatting.
- Enhanced Configuration Robustness: In `vllm/distributed/kv_transfer/kv_connector/utils.py`, an additional None check for `kv_config` was added. This improves the robustness of `get_kv_connector_cache_layout` by preventing a potential `AttributeError` when `kv_config` is `None`.
- Flexible KV Cache Layout Configuration: The PR introduces the `FLASHINFER_KV_CACHE_LAYOUT` environment variable, allowing users to override the KV cache layout for FlashInfer. This falls back to the vLLM configuration if the environment variable is not set, providing good flexibility.
- Efficient Layout Determination: The use of `functools.lru_cache` on the `get_flashinfer_kv_cache_layout` function is a good optimization, avoiding redundant recomputation of the cache layout on every forward pass (a sketch follows this list).
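For illustration, a minimal sketch of the cached layout resolution described above. The function and variable names appear in this PR; the body and the empty-string default are assumptions, not the actual implementation (the default value is truncated in the quoted diff further down).

```python
import functools
import os

from vllm.distributed.kv_transfer.kv_connector.utils import (
    get_kv_connector_cache_layout)

# Read once at import time; empty/unset means "no override" (assumed default).
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT", "")


@functools.lru_cache
def get_flashinfer_kv_cache_layout() -> str:
    # 1. An explicit user override via the environment variable wins.
    if FLASHINFER_KV_CACHE_LAYOUT:
        return FLASHINFER_KV_CACHE_LAYOUT
    # 2. Otherwise defer to the KV-connector preference resolved from the
    #    vLLM config (HND for NIXL-based disaggregated PD, NHD by default).
    return get_kv_connector_cache_layout()
```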
Merge Readiness
The pull request is generally in good shape and implements the intended functionality effectively. There is one medium-severity issue regarding `ValueError` formatting that should be addressed before merging. Once this is fixed, the PR should be ready for merge. I am unable to approve the pull request myself; please ensure it is reviewed and approved by other maintainers.
```python
elif cache_layout == "HND":
    stride_order = (0, 1, 3, 2, 4)
else:
    raise ValueError("Unknown cache layout format %s.", cache_layout)
```
The string formatting for the `ValueError` message appears to be incorrect. It looks like it is attempting `%s`-style formatting, but `cache_layout` is never interpolated into the string; it is simply passed as a second argument to the `ValueError` constructor. If this `ValueError` is actually raised, the exception message will contain the literal `%s` placeholder alongside the unformatted `cache_layout` value, which is confusing.
Could you update this to use an f-string for clarity and correctness?
raise ValueError("Unknown cache layout format %s.", cache_layout) | |
raise ValueError(f"Unknown cache layout format {cache_layout}.") |
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Did you do any performance tests to see how this impacts E2E performance for non-PD setups? I think a throughput benchmark is warranted, at least.
```diff
@@ -28,10 +32,25 @@
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
```
This should be defined like all other env variables in vLLM: in `envs.py`, as `VLLM_FLASHINFER_KV_CACHE_LAYOUT`.
I also think we should have the env variable be `VLLM_KV_CACHE_LAYOUT` rather than having a specific one for each attention backend type.
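For reference, vLLM registers env variables in `vllm/envs.py` as lazily evaluated getters; a sketch of what the reviewer is suggesting might look like the entry below. The name `VLLM_KV_CACHE_LAYOUT` and the `None` default are assumptions here, not an existing vLLM API.

```python
# Hypothetical entry in vllm/envs.py following the reviewer's suggestion.
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    # None means "no override": each attention backend keeps its own default
    # layout (NHD for FlashInfer) unless the user explicitly sets NHD or HND.
    "VLLM_KV_CACHE_LAYOUT": lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None),
}
```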
I compared a few models (llama3/qwen3) and didn't notice a huge impact, tbh. I've added a simple H100 benchmark in the description. I should also probably tag @wenscarl to elaborate on HND for flashinfer.
Signed-off-by: NickLucche <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Follow-up PR to #18775, again porting over functionality from V0 (ref #16605).
This PR enables the use of FlashInfer with an HND cache layout in V1.
Most immediately, this PR is a prerequisite for heterogeneous TP support in disaggregated prefill-decode setups, since the HND layout is better suited for KV cache transfers.
Test with:
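A minimal smoke test along these lines, assuming a FlashInfer-capable GPU; the model name and prompt are illustrative, while the environment variables are the ones read by the V1 FlashInfer backend:

```python
# Smoke-test sketch: model and prompt are illustrative assumptions.
import os

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
os.environ["FLASHINFER_KV_CACHE_LAYOUT"] = "HND"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
out = llm.generate(["The capital of France is"],
                   SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```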
A simple benchmark:
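A rough throughput comparison can be made by timing generation once with `FLASHINFER_KV_CACHE_LAYOUT=NHD` and once with `HND`; the sketch below is illustrative (model, batch size, and lengths are assumptions), not the benchmark referenced in the description:

```python
# Throughput sketch: run twice, toggling FLASHINFER_KV_CACHE_LAYOUT externally.
import os
import time

os.environ.setdefault("VLLM_USE_V1", "1")
os.environ.setdefault("VLLM_ATTENTION_BACKEND", "FLASHINFER")

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
params = SamplingParams(max_tokens=128, ignore_eos=True)
prompts = ["Hello, my name is"] * 256

start = time.perf_counter()
outputs = llm.generate(prompts, params)
elapsed = time.perf_counter() - start

generated = sum(len(o.outputs[0].token_ids) for o in outputs)
print(f"{generated / elapsed:.1f} output tokens/s "
      f"(layout={os.getenv('FLASHINFER_KV_CACHE_LAYOUT', 'NHD')})")
```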
cc @mgoin